{% comment %}
// vim: set syntax=asm :
/* mmm 40 x 2 (each 40-row column is held in 5 ymm registers of 8 f32):

    ymm0 ymm5
    ymm1 ymm6
    ymm2 ymm7
    ymm3 ymm8
    ymm4 ymm9

System V ABI:
    args: rdi, rsi, rdx, rcx, r8, r9
    preserve: rbx, rsp, rbp, r12, r13, r14, r15
    scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11
    return: rax (+rdx)

Windows ABI:
    args: RCX, RDX, R8, R9
    preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15
    scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of YMM0-15 and ZMM0-15
    return: rax (+rdx)
*/
{% endcomment %}

{% include "preamble.tmpliq" type:"f32", size:"40x2", suffix:suffix, G:G %}

// Resets the whole accumulator tile (ymm0..ymm9) to zero.
// vzeroall zeroes every YMM register, so the scratch registers
// ymm10..ymm15 are cleared as well.
{{L}}clear:
    vzeroall
    jmp     {{L}}non_linear_loop

// Accumulates a packed A x packed B product into the tile.
// Reads from the structure pointed to by rdi:
//   [rdi + 8]  -> k, the number of loop iterations
//   [rdi + 16] -> A pointer (kept in rax)
//   [rdi + 24] -> B pointer (kept in rcx)
// NOTE(review): the loop body comes from the included template, which is
// not visible here; presumably it consumes rax/rcx and advances them.
{{L}}add_mat_mul:
    mov     rcx,    [rdi + 24]   // B
    mov     rax,    [rdi + 16]   // A

    mov     rbx,    [rdi + 8]    // k
    // k == 0: nothing to accumulate, skip the loop entirely.
    test    rbx,    rbx
    jz      {{L}}non_linear_loop

{{L}}main_loop_packed_packed:
    {% include "5x2/packed_packed_loop1/avx.tmpli" %}

    // rbx counts the remaining iterations down to zero.
    dec             rbx
    jnz             {{L}}main_loop_packed_packed

    jmp             {{L}}non_linear_loop

// NON LINEAR / ADDC

{% include "fma_mmm_f32_scalars.tmpliq" from:0, to:9, type:"f32" %}
{% include "fma_mmm_f32_per_rows.tmpliq" mr:40, from:0, to:9, type:"f32" %}
{% include "fma_mmm_f32_per_cols.tmpliq" mr:40, from:0, to:9, type:"f32" %}
{% include "fma_mmm_load_tile.tmpliq" from:0, to:9 %}

// Adds a 40x2 f32 tile read from memory to the accumulators.
// Reads from the structure pointed to by rdi:
//   [rdi + 8]  -> tile base pointer
//   [rdi + 16] -> row stride, in bytes
//   [rdi + 24] -> column stride, in bytes
// Clobbers r8, r9, rsi, rbx, ymm12 and flags on the contiguous path.
{{L}}add_unicast:
    mov     r8,     [rdi + 8]           // c ptr
    mov     rsi,    [rdi + 16]          // row stride
    mov     rbx,    [rdi + 24]          // col stride

    // A row stride of 4 bytes means each column is 40 contiguous f32 and can
    // be read with plain vector loads; anything else takes the gather path.
    cmp     rsi,    4
    jne     {{L}}unicast_generic

    // The tile has only two columns: r8 walks column 0, r9 walks column 1.
    lea     r9,     [ r8 + rbx ]

{% for col in (0..1) %}
    {% for row in (0..4) %}
        vmovups ymm12,  [ r{{col|plus:8}} ]
        add     r{{col|plus:8}}, 32
        vaddps  ymm{{col|times:5|plus:row}}, ymm{{col|times:5|plus:row}}, ymm12
    {% endfor %}
{% endfor %}
    jmp    {{L}}non_linear_loop

// Strided variant of add_unicast, reached when the row stride is not 4.
// On entry: r8 = tile base pointer, rsi = row stride (bytes),
// rbx = column stride (bytes).
//
// ymm14 is filled with the eight byte offsets { 0, rs, 2*rs, ..., 7*rs } used
// as gather indices. The indices are 32-bit, so only the low 32 bits of the
// row stride (esi) take part and 7 * row stride has to fit in a dword.
{{L}}unicast_generic:
    xor     eax,    eax
{% for i in (0..3) %}
    vpinsrd xmm14, xmm14, eax, {{i}}
    add     eax,    esi
{% endfor %}
{% for i in (0..3) %}
    vpinsrd xmm15, xmm15, eax, {{i}}
    add     eax,    esi
{% endfor %}

    // VEX-encoded inserts zero the upper lanes, which is harmless: only the
    // low 128 bits of each source are selected below.
    vperm2f128      ymm14,  ymm14, ymm15,         32 // ymm14 <- xmm14::xmm15

    // r8..r12 point at the five groups of 8 rows of the current column.
    lea             r9,  [ r8 + rsi * 8 ]
    lea             r10, [ r9 + rsi * 8 ]
    lea             r11, [ r10 + rsi * 8 ]
    lea             r12, [ r11 + rsi * 8 ]

    // For each column and row group: gather 8 f32, step the row-group pointer
    // to the next column, accumulate. The gather clears its mask register, so
    // ymm15 is reset to all-ones before every gather.
{% for col in (0..1) %}
   {% for row in (0..4) %}
      vpcmpeqd        ymm15, ymm15, ymm15
      vgatherdps      ymm12, [ r{{row|plus:8}} + ymm14 ], ymm15
      add             r{{row|plus:8}}, rbx
      vaddps          ymm{{col|times:5|plus:row}}, ymm{{col|times:5|plus:row}}, ymm12
   {% endfor %}
{% endfor %}

    jmp    {{L}}non_linear_loop

// Rank-1 update of the tile: acc[r][c] += rows[r] * cols[c].
// Reads from the structure pointed to by rdi:
//   [rdi + 8]  -> pointer to 40 f32, one factor per tile row
//   [rdi + 16] -> pointer to 2 f32, one factor per tile column
// Clobbers rax, rbx, ymm10, ymm11, ymm12.
{{L}}add_row_col_products:
    mov             rbx, [ rdi + 16 ]
    mov             rax, [ rdi + 8 ]

    // One broadcast register per tile column.
    vbroadcastss    ymm11, dword ptr [rbx + 4]
    vbroadcastss    ymm10, dword ptr [rbx]
{% for chunk in (0..4) %}
    vmovups         ymm12,  [rax + {{chunk|times:32}}]
    vfmadd231ps     ymm{{chunk|plus:5}}, ymm12, ymm11
    vfmadd231ps     ymm{{chunk}}, ymm12, ymm10
{% endfor %}
    jmp    {{L}}non_linear_loop


// Writes the 40x2 accumulator tile to memory.
// Reads from the structure pointed to by rdi:
//   [rdi + 8]  -> destination base pointer
//   [rdi + 16] -> row stride, in bytes
//   [rdi + 24] -> column stride, in bytes
// r8 (column 0) and r9 (column 1) are set up here for both store paths.
{{L}}store:
    mov     r8,     [rdi + 8]           // c ptr
    mov     rsi,    [rdi + 16]          // row stride
    mov     rbx,    [rdi + 24]          // col stride

    lea     r9,     [ r8  +     rbx ]   // start of column 1

    // A row stride of 4 bytes means each column is 40 contiguous f32 and can
    // be written with plain vector stores.
    cmp         rsi, 4
    jne         {{L}}store_strides_generic

    {% for col in (0..1) %}
       {% for row in (0..4) %}
            vmovups ymmword ptr [r{{col|plus:8}}], ymm{{col|times:5|plus:row}}
            add     r{{col|plus:8}}, 32
       {% endfor %}
    {% endfor %}

    jmp     {{L}}non_linear_loop

// Strided store, reached from the store entry point when the row stride is
// not 4. On entry: r8 / r9 = start of column 0 / column 1,
// rsi = row stride (bytes).
// Each accumulator is written one f32 at a time: lanes 0..3 straight from the
// low half, lanes 4..7 through a copy of the upper half in scratch xmm12, so
// the accumulators themselves are left unmodified.
{{L}}store_strides_generic:
    {% for col in (0..1) %}
       {% for row in (0..4) %}
           {% for i in (0..3) %}
                vextractps  dword ptr [r{{col | plus: 8}}], xmm{{col | times:5 | plus:row}}, {{i}}
                add         r{{col | plus: 8}}, rsi
           {% endfor %}
           vextractf128 xmm12, ymm{{col | times:5 | plus:row}}, 1
           {% for i in (0..3) %}
                vextractps  dword ptr [r{{col | plus: 8}}], xmm12, {{i}}
                add         r{{col | plus: 8}}, rsi
           {% endfor %}
       {% endfor %}
    {% endfor %}
    jmp     {{L}}non_linear_loop

{% include "postamble.tmpliq" type:"f32", size:"40x2", suffix:suffix, G:G, L:L %}
